#from PIL import Image
import matplotlib.image as mpimg
import numpy as np
import os
from pathlib import Path
from PIL import Image
import random

from demo2_mainbody import LegendreDecomposition1
from revise_lgd import LegendreDecomposition





dataset_dir = "/tmp/pycharm_project_272/coil_100_dataset/coil-100"

# Number of images stacked along the last axis of the experiment tensor.
num_images = 2

# random.seed() returns None, so keep the seed value itself instead of its
# return value.
seed = 388
random.seed(seed)

# Only the paths of valid images are kept: decoding every image and holding
# all of them in memory is unnecessary because just `num_images` are used.
# NOTE(review): os.listdir() order is filesystem-dependent, so the same seed
# may select different images on another machine; sorting the listing would
# make the selection portable but would change which images are drawn.
valid_image_paths = []
for file_name in os.listdir(dataset_dir):
    if not file_name.endswith((".png", ".jpg", ".jpeg", ".bmp")):
        continue
    img_path = os.path.join(dataset_dir, file_name)
    img_array = np.asarray(mpimg.imread(img_path))
    # Keep only 128x128 three-channel images whose pixels are all > 0.
    if img_array.shape == (128, 128, 3) and img_array.min() > 0:
        valid_image_paths.append(img_path)

if len(valid_image_paths) < num_images:
    # Without enough images `image_data` cannot be built and every later step
    # would fail, so stop here with a clear message.
    raise SystemExit(f"Not enough valid images. Found only {len(valid_image_paths)} images.")

# random.sample() chooses positions from len(population) and the RNG state, so
# sampling the paths selects the same images as sampling decoded arrays would.
selected_paths = random.sample(valid_image_paths, num_images)
selected_images = [np.asarray(mpimg.imread(path)) for path in selected_paths]
print(f"Randomly selected {num_images} images from {len(valid_image_paths)} valid images.")
print(f"Selected images: {len(selected_images)}")
print(f"Total images with all elements greater than 0: {len(valid_image_paths)}")

# Stack the images along a new last axis: shape (128, 128, 3, num_images).
image_data = np.stack(selected_images, axis=-1)
print(f"Tensor shape before resizing: {image_data.shape}")
max_value = image_data.max()
min_value = image_data.min()
print(f"Max value across the selected images: {max_value}")
print(f"Min value across the selected images: {min_value}")

# Spatial crop sizes swept by the experiment loop: 4, 28, 52, 76, 100, 124.
sizes = list(range(4, 129, 24))
#sizes = [43]
print(sizes)

def out_put_coordinates(tensor):
    """Split the 0/1 entries of ``tensor`` into coordinate lists.

    Parameters
    ----------
    tensor : numpy.ndarray
        Array compared elementwise against 1 and 0 (for example the mask
        returned by ``change_parameter``). Entries with other values are
        ignored.

    Returns
    -------
    coordinates : list of tuple
        Multi-indices of every entry equal to 1, in row-major (C) order.
    coordinates1 : list of tuple
        A copy of ``coordinates`` with the *second* zero-valued index
        appended when at least two zero entries exist; otherwise just a copy
        of ``coordinates``.
    coordinates2 : list of tuple
        Multi-indices of every entry equal to 0 except the first one in
        row-major order; empty when the tensor has at most one zero.
    """
    ones = [tuple(index) for index in np.argwhere(tensor == 1)]
    zeros = [tuple(index) for index in np.argwhere(tensor == 0)]

    coordinates1 = list(ones)
    if len(zeros) >= 2:
        coordinates1.append(zeros[1])

    # Drop the first zero index (for a change_parameter mask this is the
    # origin). Slicing instead of `del zeros[0]` avoids an IndexError when
    # the tensor contains no zeros at all.
    return ones, coordinates1, zeros[1:]


def change_parameter(tensor, k):
    """Build a 0/1 mask selecting indices with between 1 and ``k`` non-zero coordinates.

    The entry at multi-index ``(i_1, ..., i_d)`` is 1 when the number of
    non-zero ``i_j`` is at least 1 and at most ``k``. All other entries are 0,
    including the origin ``(0, ..., 0)``. Only ``tensor.shape`` is used.

    Parameters
    ----------
    tensor : numpy.ndarray
        Array whose shape defines the mask shape.
    k : int
        Maximum number of non-zero coordinates an index may have.

    Returns
    -------
    numpy.ndarray
        Integer array with the same shape as ``tensor``, containing 0/1.

    Raises
    ------
    ValueError
        If ``k`` is not strictly less than the number of dimensions of ``tensor``.
    """
    dims = tensor.shape
    if k >= len(dims):
        raise ValueError("k must be strictly less than the tensor's dimension.")

    # Count, for every multi-index, how many of its coordinates are non-zero.
    # The sparse index grids broadcast against the full-shape accumulator, so
    # no Python-level loop over the elements is needed.
    non_zero_count = np.zeros(dims, dtype=int)
    for axis_index in np.indices(dims, sparse=True):
        non_zero_count += axis_index != 0

    return ((non_zero_count >= 1) & (non_zero_count <= k)).astype(int)

def calculate_s(P):
    """Return ``log(sum(P) / min(P)) / log(number of elements of P)``.

    Equivalently, the exponent ``s`` satisfying
    ``sum(P) / min(P) == size(P) ** s``.

    Parameters
    ----------
    P : numpy.ndarray
        Input tensor. The intended inputs have more than one element and a
        strictly positive minimum. Otherwise the division or logarithm
        produces inf/nan.

    Returns
    -------
    numpy.float64
        The exponent ``s``.
    """
    ratio = np.sum(P) / np.min(P)
    n_elements = np.prod(P.shape)
    return np.log(ratio) / np.log(n_elements)

# Earlier version of is_c_in_range (no `alpha` parameter, 2 ** (d - 1) and
# I_i in the numerator), kept for reference. It is wrapped in a bare
# module-level string literal, so it is never defined or executed; the active
# implementation is the is_c_in_range defined after this string.
'''
def is_c_in_range(P, s, c):
    tensor_shape = P.shape
    d = len(tensor_shape)
    I_prod = np.prod(tensor_shape)
    log_I_prod = np.log(I_prod)

    max_lower_bound = float('-inf')
    min_upper_bound = float('inf')

    for i in range(d):
        I_i = tensor_shape[i]
        prod_except_i = np.prod([tensor_shape[j] + 1 for j in range(d) if j != i])

        lower_bound = -(2 ** (d - 1) * ((s - 1) * d + 1) * I_i * log_I_prod) / prod_except_i
        upper_bound = (2 ** (d - 1) * ((s - 1) * d + 1) * I_i * log_I_prod) / prod_except_i

        max_lower_bound = max(max_lower_bound, lower_bound)
        min_upper_bound = min(min_upper_bound, upper_bound)

    # Check if c is within the range
    if max_lower_bound <= c <= min_upper_bound:
        print(f'c = {c} is within the range: [{max_lower_bound}, {min_upper_bound}]')
        return min_upper_bound,1
    else:
        print(f'c = {c} is outside the range: [{max_lower_bound}, {min_upper_bound}]')
        return min_upper_bound,0
'''

def is_c_in_range(P, s, c, alpha):
    """Check whether ``c`` lies in the admissible interval derived from ``P``.

    With ``d = P.ndim`` and axis lengths ``I_1, ..., I_d``, every axis ``i``
    contributes a symmetric interval ``[-b_i, b_i]`` where::

        b_i = 2**d * ((s - 1) * d + 1) * log(I_1 * ... * I_d)
              / (((1 - 1/alpha) * I_i + 1) * prod_{j != i} (I_j + 1))

    The admissible interval is their intersection,
    ``[max_i(-b_i), min_i(b_i)]``. The interval and the verdict are printed.

    Parameters
    ----------
    P : numpy.ndarray
        Tensor whose shape determines the bounds.
    s : float
        Exponent as returned by ``calculate_s``.
    c : float
        Value to test against the interval.
    alpha : float
        Parameter entering the denominator through ``1 - 1/alpha``; must be
        non-zero.

    Returns
    -------
    tuple
        ``(min_upper_bound, flag)`` where ``flag`` is 1 if ``c`` is inside
        the closed interval and 0 otherwise.
    """
    shape = P.shape
    ndim = len(shape)
    log_total = np.log(np.prod(shape))

    # Axis-independent factor of every bound, evaluated in the same order as
    # the full expression so the floating-point results are unchanged.
    numerator = 2 ** ndim * ((s - 1) * ndim + 1) * log_total

    upper = []
    for axis, length in enumerate(shape):
        others = np.prod([shape[j] + 1 for j in range(ndim) if j != axis])
        upper.append(numerator / (((1 - 1 / alpha) * length + 1) * others))

    # Each lower bound is the exact negation of its axis' upper bound.
    lower_limit = max(-bound for bound in upper)
    upper_limit = min(upper)

    if lower_limit <= c <= upper_limit:
        print(f'c = {c} is within the range: [{lower_limit}, {upper_limit}]')
        return upper_limit, 1
    print(f'c = {c} is outside the range: [{lower_limit}, {upper_limit}]')
    return upper_limit, 0









def _write_lines(path, values):
    """Overwrite ``path`` with ``str(value)`` of each item, one per line."""
    with open(path, 'w') as file:
        for value in values:
            file.write(str(value) + '\n')


# For each k (the maximum number of non-zero coordinates allowed by
# change_parameter), sweep the crop sizes, fit both Legendre decompositions and
# record c, the upper bound returned by is_c_in_range and its in-range flag.
for k in range(1, 4):
    results_c = []
    results_upperbound = []
    rights_line = []
    results_sumsize = []

    # The output folder depends only on k, so create it once per k rather
    # than on every size iteration.
    folder = Path(f'size_test/bound_coil_{k}body_seed_388')
    folder.mkdir(parents=True, exist_ok=True)

    for size in sizes:
        P = image_data[:size, :size, :3, :2]
        print('size=', size)

        # Number of elements actually in the crop. For 128x128x3x2 data this
        # equals size*size*3*2, but unlike a hard-coded product it stays
        # correct when the slice is clipped by the array bounds.
        sumsize = P.size
        print('sumsize=', sumsize)
        results_sumsize.append(sumsize)

        binary_tensor = change_parameter(P, k)
        coordinates, coordinates1, coordinates_complement = out_put_coordinates(binary_tensor)

        ld_ori = LegendreDecomposition(solver='ng', max_iter=50000, verbose=0, learning_rate=0.001)
        reconst_tensor_ori = ld_ori.fit_transform(P, coordinates)
        print('Reconstruction error(RSE): {}'.format(ld_ori.reconstruction_err_))

        # Second decomposition, seeded with the theta learned by the first fit.
        ld_imp = LegendreDecomposition1(solver='ng', max_iter=500, verbose=0, learning_rate=0.001)
        reconst_tensor_imp = ld_imp.fit_transform(
            P, coordinates, coordinates1, coordinates_complement, ld_ori.theta
        )
        print('Reconstruction error(RSE): {}'.format(ld_imp.reconstruction_err_))

        results_c.append(ld_imp.c)
        print('c=', ld_imp.c)
        s = calculate_s(P)
        print('s=', s)
        upperbounds, rights = is_c_in_range(P, s, ld_imp.c, 2)
        results_upperbound.append(upperbounds)
        rights_line.append(rights)

        # Rewrite every result file after each size, so the files always
        # reflect all sizes completed so far for this k.
        _write_lines(folder / 'results_upperbound.txt', results_upperbound)
        _write_lines(folder / 'rights_line_append.txt', rights_line)
        _write_lines(folder / 'results_c.txt', results_c)
        _write_lines(folder / 'results_sizes.txt', results_sumsize)
